# Load the modeling dataset and drop the tract identifier (not a predictor)
df <- readRDS("../data/models/social-risk-crash-rate-data.rds")
df$geoid <- NULL
glimpse(df)
Rows: 13,518
Columns: 44
$ year <int> 2018, 2019, 2020, 2021, 2022, 2023, 2018, 2…
$ total_population <int> 51349, 50967, 49194, 52308, 59590, 63478, 7…
$ pct_male_population <dbl> 12.460865, 11.694881, 12.229106, 13.550457,…
$ pct_female_population <dbl> 10.846132, 11.629654, 11.054169, 9.837846, …
$ pct_white_population <dbl> 4.1225202, 3.6815086, 2.5856754, 2.3610065,…
$ pct_black_population <dbl> 2.435429, 2.630197, 2.833673, 2.412624, 2.4…
$ pct_asian_population <dbl> 0.19803330, 0.14388387, 0.23376782, 0.30587…
$ pct_hispanic_population <dbl> 6.411328, 6.607147, 5.982424, 6.079353, 5.2…
$ pct_foreign_born <dbl> 2.909566, 2.442189, 2.187254, 2.200420, 3.0…
$ pct_age_under_18 <dbl> 2.130762, 2.643626, 2.333613, 2.387771, 1.7…
$ pct_age_18_34 <dbl> 1.715654, 1.653705, 1.882339, 1.963363, 2.0…
$ pct_age_35_64 <dbl> 3.296112, 3.079115, 3.041014, 3.043500, 3.3…
$ pct_age_65_plus <dbl> 1.80895804, 1.78607845, 1.56929356, 1.57910…
$ median_income <dbl> 58582.658, 49964.513, 68000.000, 70867.000,…
$ pct_income_under_25k <dbl> 0.5445916, 0.5544325, 0.5325841, 0.5142597,…
$ pct_income_25k_75k <dbl> 1.0568123, 1.1376418, 0.9533662, 0.8774915,…
$ pct_income_75k_plus <dbl> 0.9273290, 0.8824877, 1.2379531, 1.2693994,…
$ pct_below_poverty <dbl> 1.9574830, 1.9510653, 1.8091597, 1.9385106,…
$ median_gross_rent <dbl> 1579.1133, 1524.3577, 1701.0000, 1740.0000,…
$ pct_owner_occupied <dbl> 1.33318303, 1.35699756, 1.58439699, 1.46745…
$ pct_renter_occupied <dbl> 1.2085934, 1.2305140, 1.1518859, 1.2057574,…
$ pct_no_vehicle <dbl> 0.5122207, 0.6522735, 0.6118619, 0.6270527,…
$ pct_less_than_hs <dbl> 1.4509748, 1.1951954, 1.1302166, 1.0495486,…
$ pct_hs_diploma <dbl> 1.5576081, 1.4196542, 1.2704773, 1.3841042,…
$ pct_some_college <dbl> 0.8473540, 0.8997538, 0.7256966, 0.9348439,…
$ pct_associates_degree <dbl> 0.4912749, 0.3280552, 0.3923234, 0.3230851,…
$ pct_bachelors_degree <dbl> 0.9482748, 0.9457966, 0.8537607, 0.9023442,…
$ pct_graduate_degree <dbl> 0.3998749, 0.7098271, 1.1037907, 0.9673436,…
$ pct_in_labor_force <dbl> 3.566504, 3.430191, 3.555304, 3.595995, 4.0…
$ pct_not_in_labor_force <dbl> 3.448445, 3.326595, 3.162980, 3.106587, 2.8…
$ unemployment_rate <dbl> 15.750133, 13.478747, 10.806175, 11.536417,…
$ pct_commute_short <dbl> 0.7350082, 0.4604284, 0.1585556, 0.4167607,…
$ pct_commute_medium <dbl> 1.7061331, 1.6882374, 1.2521824, 1.2579290,…
$ pct_commute_long <dbl> 2.477320, 2.685832, 3.553271, 3.737464, 4.6…
$ pct_carpool <dbl> 0.35798327, 0.31462606, 0.00000000, 0.04970…
$ pct_public_transit <dbl> 2.056500, 2.540030, 2.898721, 2.366742, 2.8…
$ pct_walk <dbl> 0.37702494, 0.30695226, 0.13009688, 0.76469…
$ pct_bike <dbl> 0.00000000, 0.00000000, 0.00000000, 0.00000…
$ pct_work_from_home <dbl> 0.01713750, 0.04604284, 0.04675356, 0.05161…
$ crash_rate_per_1000 <dbl> 0.8373993, 0.5886150, 0.6301567, 0.5735238,…
$ injury_rate_per_1000 <dbl> 0.2142184, 0.1569640, 0.2439316, 0.2294095,…
$ fatality_rate_per_1000 <dbl> 0.00000000, 0.00000000, 0.00000000, 0.00000…
$ pct_vehicle <dbl> 2.0165122, 1.9222885, 2.1120415, 2.0340979,…
$ borough <chr> "Bronx", "Bronx", "Bronx", "Bronx", "Bronx"…
# Add a pre/post-pandemic indicator: 1 for 2020 and later, 0 for earlier years.
# (The original round-tripped through "pre"/"post" strings and a second mutate;
# the comparison can produce the integer flag directly.)
df <- df %>%
  mutate(post_pandemic = as.integer(year >= 2020))
glimpse(df)
Rows: 13,518
Columns: 45
$ year <int> 2018, 2019, 2020, 2021, 2022, 2023, 2018, 2…
$ total_population <int> 51349, 50967, 49194, 52308, 59590, 63478, 7…
$ pct_male_population <dbl> 12.460865, 11.694881, 12.229106, 13.550457,…
$ pct_female_population <dbl> 10.846132, 11.629654, 11.054169, 9.837846, …
$ pct_white_population <dbl> 4.1225202, 3.6815086, 2.5856754, 2.3610065,…
$ pct_black_population <dbl> 2.435429, 2.630197, 2.833673, 2.412624, 2.4…
$ pct_asian_population <dbl> 0.19803330, 0.14388387, 0.23376782, 0.30587…
$ pct_hispanic_population <dbl> 6.411328, 6.607147, 5.982424, 6.079353, 5.2…
$ pct_foreign_born <dbl> 2.909566, 2.442189, 2.187254, 2.200420, 3.0…
$ pct_age_under_18 <dbl> 2.130762, 2.643626, 2.333613, 2.387771, 1.7…
$ pct_age_18_34 <dbl> 1.715654, 1.653705, 1.882339, 1.963363, 2.0…
$ pct_age_35_64 <dbl> 3.296112, 3.079115, 3.041014, 3.043500, 3.3…
$ pct_age_65_plus <dbl> 1.80895804, 1.78607845, 1.56929356, 1.57910…
$ median_income <dbl> 58582.658, 49964.513, 68000.000, 70867.000,…
$ pct_income_under_25k <dbl> 0.5445916, 0.5544325, 0.5325841, 0.5142597,…
$ pct_income_25k_75k <dbl> 1.0568123, 1.1376418, 0.9533662, 0.8774915,…
$ pct_income_75k_plus <dbl> 0.9273290, 0.8824877, 1.2379531, 1.2693994,…
$ pct_below_poverty <dbl> 1.9574830, 1.9510653, 1.8091597, 1.9385106,…
$ median_gross_rent <dbl> 1579.1133, 1524.3577, 1701.0000, 1740.0000,…
$ pct_owner_occupied <dbl> 1.33318303, 1.35699756, 1.58439699, 1.46745…
$ pct_renter_occupied <dbl> 1.2085934, 1.2305140, 1.1518859, 1.2057574,…
$ pct_no_vehicle <dbl> 0.5122207, 0.6522735, 0.6118619, 0.6270527,…
$ pct_less_than_hs <dbl> 1.4509748, 1.1951954, 1.1302166, 1.0495486,…
$ pct_hs_diploma <dbl> 1.5576081, 1.4196542, 1.2704773, 1.3841042,…
$ pct_some_college <dbl> 0.8473540, 0.8997538, 0.7256966, 0.9348439,…
$ pct_associates_degree <dbl> 0.4912749, 0.3280552, 0.3923234, 0.3230851,…
$ pct_bachelors_degree <dbl> 0.9482748, 0.9457966, 0.8537607, 0.9023442,…
$ pct_graduate_degree <dbl> 0.3998749, 0.7098271, 1.1037907, 0.9673436,…
$ pct_in_labor_force <dbl> 3.566504, 3.430191, 3.555304, 3.595995, 4.0…
$ pct_not_in_labor_force <dbl> 3.448445, 3.326595, 3.162980, 3.106587, 2.8…
$ unemployment_rate <dbl> 15.750133, 13.478747, 10.806175, 11.536417,…
$ pct_commute_short <dbl> 0.7350082, 0.4604284, 0.1585556, 0.4167607,…
$ pct_commute_medium <dbl> 1.7061331, 1.6882374, 1.2521824, 1.2579290,…
$ pct_commute_long <dbl> 2.477320, 2.685832, 3.553271, 3.737464, 4.6…
$ pct_carpool <dbl> 0.35798327, 0.31462606, 0.00000000, 0.04970…
$ pct_public_transit <dbl> 2.056500, 2.540030, 2.898721, 2.366742, 2.8…
$ pct_walk <dbl> 0.37702494, 0.30695226, 0.13009688, 0.76469…
$ pct_bike <dbl> 0.00000000, 0.00000000, 0.00000000, 0.00000…
$ pct_work_from_home <dbl> 0.01713750, 0.04604284, 0.04675356, 0.05161…
$ crash_rate_per_1000 <dbl> 0.8373993, 0.5886150, 0.6301567, 0.5735238,…
$ injury_rate_per_1000 <dbl> 0.2142184, 0.1569640, 0.2439316, 0.2294095,…
$ fatality_rate_per_1000 <dbl> 0.00000000, 0.00000000, 0.00000000, 0.00000…
$ pct_vehicle <dbl> 2.0165122, 1.9222885, 2.1120415, 2.0340979,…
$ borough <chr> "Bronx", "Bronx", "Bronx", "Bronx", "Bronx"…
$ post_pandemic <int> 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1…
# Engineered interaction terms:
# - vehicle share scaled by log population (density proxy)
# - income crossed with vehicle share
df$car_density_interaction <- df$pct_vehicle * log1p(df$total_population)
df$income_vehicle_interaction <- df$median_income * df$pct_vehicle
YEAR: I will treat year as a factor and one-hot encode it.
# One-hot encode year: convert to factor, expand to indicator columns
# (~ year - 1 drops the intercept so every year level gets its own dummy),
# then replace the original year column with the dummies.
df$year <- as.factor(df$year)
year_dummies <- model.matrix(~ year - 1, data = df)
df <- cbind(df[, setdiff(names(df), "year")], year_dummies)
We remove all candidate target variables, keeping only the one selected for this model's training.
# Choose your target variable (e.g., crash rate per 1,000 residents)
target_var <- "crash_rate_per_1000"

# Remove all rate columns except the selected target, so no leakage between
# alternative outcomes. (The original had a stray trailing comma in select(),
# which passes a missing argument, and used a bare external variable in
# select(), which tidyselect treats ambiguously — wrap in all_of().)
cols_to_remove <- setdiff(
  grep("per_1000", names(df), value = TRUE),
  target_var
)
df <- df %>% select(-all_of(cols_to_remove))

# Create feature matrix and target vector; borough and raw population are
# held out of the predictors (borough drives the spatial CV folds below)
X <- df %>% select(-all_of(target_var), -borough, -total_population)
y <- df[[target_var]]
What This Does - Uses R's xgboost::xgb.train() over manually built
leave-one-borough-out folds to evaluate each parameter set (xgb.cv() is not
used; the spatial folds are constructed explicitly). - Optuna (Python) handles
the search space and Bayesian optimization. - The final best parameters are
applied to fit the final_model. - Search space: instead of predefined grids,
trial$suggest_float() and trial$suggest_int() explore a range of values. -
Best parameters: study$best_params holds the optimal hyperparameters.
## CONVERT TO DMATRIX
# Single DMatrix over the full dataset (its row count is reused in the CV loop)
dtrain_all <- xgb.DMatrix(data = as.matrix(X), label = y)
## Start Python venv
# reticulate must point at the virtualenv where optuna is installed
reticulate::use_virtualenv("r-reticulate", required = TRUE)
## OPTUNA-BASED SPATIAL CV
optuna <- import("optuna")
# Leave-one-borough-out folds: each element holds TRAINING row indices
# (all rows NOT in borough b); the held-out borough is that fold's validation set
boroughs <- unique(df$borough)
folds <- lapply(boroughs, function(b) which(df$borough != b))
# Optuna objective: leave-one-borough-out spatial CV for one hyperparameter set.
# Returns the mean early-stopped validation RMSE across all folds.
# BUG FIX: the original accumulated rmse_scores but then returned
# Metrics::rmse() of only the LAST fold's predictions, discarding the vector.
objective <- function(trial) {
  params <- list(
    booster = "gbtree",
    # Explicit objective/metric (xgboost's regression defaults, stated for
    # consistency with the final model fit)
    objective = "reg:squarederror",
    eval_metric = "rmse",
    eta = trial$suggest_float("eta", 0.01, 0.3, log = TRUE),
    max_depth = trial$suggest_int("max_depth", 3, 12),
    min_child_weight = trial$suggest_int("min_child_weight", 1, 10),
    subsample = trial$suggest_float("subsample", 0.5, 1.0),
    colsample_bytree = trial$suggest_float("colsample_bytree", 0.5, 1.0),
    gamma = trial$suggest_float("gamma", 0, 10),
    lambda = trial$suggest_float("lambda", 0, 10),
    alpha = trial$suggest_float("alpha", 0, 10)
  )

  rmse_scores <- numeric(length(folds))
  for (i in seq_along(folds)) {
    train_idx <- folds[[i]]
    valid_idx <- setdiff(seq_len(nrow(X)), train_idx)

    dtrain <- xgb.DMatrix(data = as.matrix(X[train_idx, ]), label = y[train_idx])
    dvalid <- xgb.DMatrix(data = as.matrix(X[valid_idx, ]), label = y[valid_idx])

    model <- xgb.train(
      params = params,
      data = dtrain,
      nrounds = 500,
      watchlist = list(val = dvalid),
      early_stopping_rounds = 20,
      verbose = 0
    )

    # Best (early-stopped) validation RMSE on the held-out borough
    rmse_scores[i] <- min(model$evaluation_log$val_rmse)
  }

  # Aggregate so every borough contributes to the trial's score
  mean(rmse_scores)
}
# Run Optuna study
# NOTE: set.seed() only seeds R's RNG; it does NOT make the Python-side search
# reproducible. Seed Optuna's TPE sampler explicitly so the trial sequence
# (and thus best_params) is reproducible across runs.
set.seed(2025)
sampler <- optuna$samplers$TPESampler(seed = 2025L)
study <- optuna$create_study(direction = "minimize", sampler = sampler)
study$optimize(objective, n_trials = 50L)
best_params <- study$best_params
print(best_params)
$eta
[1] 0.1201187
$max_depth
[1] 5
$min_child_weight
[1] 8
$subsample
[1] 0.6164199
$colsample_bytree
[1] 0.6999922
$gamma
[1] 3.719088
$lambda
[1] 5.17341
$alpha
[1] 3.270779
# Set seed
set.seed(2025)
# Split by index
# NOTE(review): this is a plain random 80/20 split, while hyperparameters were
# tuned with leave-one-borough-out CV — test rows share boroughs/years with
# training rows. Confirm this evaluation matches the intended (spatial) use.
train_index <- createDataPartition(y, p = 0.8, list = FALSE)
X_train <- X[train_index, ]
y_train <- y[train_index]
X_test <- X[-train_index, ]
y_test <- y[-train_index]
# Convert to xgb.DMatrix
dtrain <- xgb.DMatrix(data = as.matrix(X_train), label = y_train)
dtest <- xgb.DMatrix(data = as.matrix(X_test), label = y_test)
# Set seed
set.seed(2025)
# Fit the final model with ALL tuned hyperparameters and early stopping on
# the held-out test set.
# BUG FIX: the original omitted the tuned lambda and alpha regularization
# terms, so the final model silently fell back to xgboost's defaults for both.
final_model <- xgb.train(
  params = list(
    eta = best_params$eta,
    max_depth = best_params$max_depth,
    gamma = best_params$gamma,
    colsample_bytree = best_params$colsample_bytree,
    min_child_weight = best_params$min_child_weight,
    subsample = best_params$subsample,
    lambda = best_params$lambda,
    alpha = best_params$alpha,
    objective = "reg:squarederror",
    eval_metric = "rmse"
  ),
  data = dtrain,
  nrounds = 1000,
  watchlist = list(train = dtrain, test = dtest),
  early_stopping_rounds = 20,
  verbose = 1,
  nthread = detectCores() - 1
)
[1] train-rmse:1.959935 test-rmse:1.943103
Multiple eval metrics are present. Will use test_rmse for early stopping.
Will train until test_rmse hasn't improved in 20 rounds.
[2] train-rmse:1.885766 test-rmse:1.870772
[3] train-rmse:1.826508 test-rmse:1.818521
[4] train-rmse:1.786537 test-rmse:1.782901
[5] train-rmse:1.741575 test-rmse:1.746282
[6] train-rmse:1.701737 test-rmse:1.715650
[7] train-rmse:1.675409 test-rmse:1.692276
[8] train-rmse:1.649850 test-rmse:1.670641
[9] train-rmse:1.628597 test-rmse:1.661646
[10] train-rmse:1.606858 test-rmse:1.642966
[11] train-rmse:1.597939 test-rmse:1.636465
[12] train-rmse:1.578917 test-rmse:1.631226
[13] train-rmse:1.564573 test-rmse:1.621342
[14] train-rmse:1.549118 test-rmse:1.614985
[15] train-rmse:1.539681 test-rmse:1.610942
[16] train-rmse:1.523881 test-rmse:1.601895
[17] train-rmse:1.513709 test-rmse:1.593075
[18] train-rmse:1.507580 test-rmse:1.586184
[19] train-rmse:1.500953 test-rmse:1.586513
[20] train-rmse:1.493454 test-rmse:1.584122
[21] train-rmse:1.490005 test-rmse:1.581247
[22] train-rmse:1.478843 test-rmse:1.575141
[23] train-rmse:1.476069 test-rmse:1.573897
[24] train-rmse:1.470877 test-rmse:1.570920
[25] train-rmse:1.466444 test-rmse:1.570463
[26] train-rmse:1.459459 test-rmse:1.568340
[27] train-rmse:1.451448 test-rmse:1.567919
[28] train-rmse:1.447033 test-rmse:1.569194
[29] train-rmse:1.439085 test-rmse:1.566178
[30] train-rmse:1.435631 test-rmse:1.564914
[31] train-rmse:1.433212 test-rmse:1.564180
[32] train-rmse:1.427243 test-rmse:1.561112
[33] train-rmse:1.422343 test-rmse:1.560981
[34] train-rmse:1.418735 test-rmse:1.561259
[35] train-rmse:1.414467 test-rmse:1.559647
[36] train-rmse:1.409145 test-rmse:1.558684
[37] train-rmse:1.402428 test-rmse:1.556404
[38] train-rmse:1.397734 test-rmse:1.554198
[39] train-rmse:1.396167 test-rmse:1.553283
[40] train-rmse:1.391987 test-rmse:1.552270
[41] train-rmse:1.387588 test-rmse:1.552181
[42] train-rmse:1.384995 test-rmse:1.552003
[43] train-rmse:1.378946 test-rmse:1.550981
[44] train-rmse:1.375850 test-rmse:1.553506
[45] train-rmse:1.372297 test-rmse:1.553599
[46] train-rmse:1.368335 test-rmse:1.553950
[47] train-rmse:1.364855 test-rmse:1.551612
[48] train-rmse:1.359726 test-rmse:1.551580
[49] train-rmse:1.358148 test-rmse:1.550497
[50] train-rmse:1.354803 test-rmse:1.550000
[51] train-rmse:1.350247 test-rmse:1.550253
[52] train-rmse:1.347000 test-rmse:1.552812
[53] train-rmse:1.344932 test-rmse:1.551698
[54] train-rmse:1.340769 test-rmse:1.551204
[55] train-rmse:1.339520 test-rmse:1.551740
[56] train-rmse:1.335570 test-rmse:1.546184
[57] train-rmse:1.329665 test-rmse:1.550643
[58] train-rmse:1.326548 test-rmse:1.552036
[59] train-rmse:1.324028 test-rmse:1.549153
[60] train-rmse:1.320261 test-rmse:1.549236
[61] train-rmse:1.315158 test-rmse:1.547558
[62] train-rmse:1.310720 test-rmse:1.545829
[63] train-rmse:1.304922 test-rmse:1.545195
[64] train-rmse:1.298181 test-rmse:1.542381
[65] train-rmse:1.296212 test-rmse:1.542564
[66] train-rmse:1.292331 test-rmse:1.543948
[67] train-rmse:1.288435 test-rmse:1.543639
[68] train-rmse:1.286875 test-rmse:1.541403
[69] train-rmse:1.283858 test-rmse:1.540771
[70] train-rmse:1.278343 test-rmse:1.536453
[71] train-rmse:1.276273 test-rmse:1.535691
[72] train-rmse:1.273783 test-rmse:1.534360
[73] train-rmse:1.270261 test-rmse:1.532495
[74] train-rmse:1.266530 test-rmse:1.532205
[75] train-rmse:1.265671 test-rmse:1.531939
[76] train-rmse:1.263510 test-rmse:1.531678
[77] train-rmse:1.261402 test-rmse:1.532488
[78] train-rmse:1.259280 test-rmse:1.532332
[79] train-rmse:1.257066 test-rmse:1.534984
[80] train-rmse:1.254124 test-rmse:1.534008
[81] train-rmse:1.252545 test-rmse:1.533317
[82] train-rmse:1.246175 test-rmse:1.533971
[83] train-rmse:1.244708 test-rmse:1.532679
[84] train-rmse:1.238223 test-rmse:1.533592
[85] train-rmse:1.234989 test-rmse:1.531914
[86] train-rmse:1.232076 test-rmse:1.529619
[87] train-rmse:1.229449 test-rmse:1.528281
[88] train-rmse:1.227759 test-rmse:1.527017
[89] train-rmse:1.227075 test-rmse:1.528388
[90] train-rmse:1.224285 test-rmse:1.530681
[91] train-rmse:1.222168 test-rmse:1.530990
[92] train-rmse:1.216858 test-rmse:1.531737
[93] train-rmse:1.215050 test-rmse:1.531465
[94] train-rmse:1.214020 test-rmse:1.531490
[95] train-rmse:1.213599 test-rmse:1.530946
[96] train-rmse:1.211771 test-rmse:1.530869
[97] train-rmse:1.209781 test-rmse:1.532405
[98] train-rmse:1.208736 test-rmse:1.533464
[99] train-rmse:1.208342 test-rmse:1.533837
[100] train-rmse:1.203152 test-rmse:1.535616
[101] train-rmse:1.200668 test-rmse:1.535577
[102] train-rmse:1.198942 test-rmse:1.535120
[103] train-rmse:1.196072 test-rmse:1.536238
[104] train-rmse:1.192535 test-rmse:1.533738
[105] train-rmse:1.189005 test-rmse:1.531056
[106] train-rmse:1.187788 test-rmse:1.529239
[107] train-rmse:1.183801 test-rmse:1.528171
[108] train-rmse:1.183145 test-rmse:1.528352
Stopping. Best iteration:
[88] train-rmse:1.227759 test-rmse:1.527017
# Ensure the output directory exists before writing artifacts
model_dir <- "../data/models"
if (!dir.exists(model_dir)) {
  dir.create(model_dir, recursive = TRUE)
}
# Persist the fitted model and the tuned hyperparameters.
# NOTE(review): saveRDS() on an xgboost model handle can break across xgboost
# versions; xgb.save() is the package-recommended format — confirm consumers.
saveRDS(final_model, file = file.path(model_dir, "spatial_cv_model.rds"))
saveRDS(best_params, file = file.path(model_dir, "spatial_cv_best_params.rds"))
cat("Model and parameters saved to ../data/models/")
Model and parameters saved to ../data/models/
library(Metrics)
library(ggplot2)
library(dplyr)

set.seed(2025)

# Predict on test set
preds <- predict(final_model, as.matrix(X_test))

# --- Metrics ---
rmse <- sqrt(mean((y_test - preds)^2))
mae <- mean(abs(y_test - preds))
# MAPE is undefined where the observed rate is zero. The data contains
# zero crash/fatality rates, so dividing by y_test directly (as the original
# did) yields Inf/NaN. Compute MAPE over the non-zero observations only.
nonzero <- y_test != 0
mape <- mean(abs((y_test[nonzero] - preds[nonzero]) / y_test[nonzero])) * 100
r2 <- 1 - (sum((y_test - preds)^2) / sum((y_test - mean(y_test))^2))

cat("Model Evaluation Metrics:\n")
cat(" RMSE:", rmse, "\n")
cat(" MAE :", mae, "\n")
cat(" MAPE:", mape, "%\n")
cat(" R² :", r2, "\n\n")

# --- Residuals ---
residuals <- y_test - preds
residual_df <- data.frame(
  actual = y_test,
  predicted = preds,
  residuals = residuals
)
# --- Plot: Predicted vs Actual ---
# Points on the red identity line are perfect predictions
p1 <- residual_df %>%
  ggplot(aes(x = actual, y = predicted)) +
  geom_point(alpha = 0.5) +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  theme_minimal() +
  labs(title = "Predicted vs Actual Crash Rates",
       x = "Actual",
       y = "Predicted")

# --- Plot: Residuals vs Predicted ---
# A fan/trend here indicates heteroscedasticity or systematic bias
p2 <- residual_df %>%
  ggplot(aes(x = predicted, y = residuals)) +
  geom_point(alpha = 0.5, color = "blue") +
  geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
  theme_minimal() +
  labs(title = "Residuals vs Predicted",
       x = "Predicted",
       y = "Residuals")

# --- Plot: Residual Density ---
# after_stat(density) replaces the deprecated ..density.. notation
p3 <- residual_df %>%
  ggplot(aes(x = residuals)) +
  geom_histogram(aes(y = after_stat(density)), bins = 30, fill = "skyblue", alpha = 0.7) +
  geom_density(color = "red") +
  theme_minimal() +
  labs(title = "Residual Distribution",
       x = "Residuals",
       y = "Density")

# Print plots
print(p1)
print(p2)
print(p3)
# Compute SHAP values
# NOTE(review): shap.values()/shap.prep()/shap.plot.summary() come from the
# SHAPforxgboost package, which is not loaded in this chunk — confirm
# library(SHAPforxgboost) is called earlier in the document.
shap_values <- shap.values(xgb_model = final_model, X_train = as.matrix(X_train))
shap_long <- shap.prep(shap_contrib = shap_values$shap_score, X_train = as.matrix(X_train))
# SHAP summary plot
print(shap.plot.summary(shap_long))
if (!dir.exists("../report/plots")) {
  dir.create("../report/plots")
}
# Re-render the summary plot into a PNG for the report, then close the device
png("../report/plots/spatial_cv_shap_summary_plot.png", width = 1200, height = 800)
shap.plot.summary(shap_long)
dev.off()
quartz_off_screen
2
# Inspect the first three boosted trees individually.
# (Kept as separate top-level calls: each returns an htmlwidget that only
# auto-prints at top level, so a for-loop would render nothing.)
xgb.plot.tree(model = final_model, trees = 0)
xgb.plot.tree(model = final_model, trees = 1)
xgb.plot.tree(model = final_model, trees = 2)
# Collapsed "multi-tree" view summarizing the whole ensemble in one diagram
xgb.plot.multi.trees(model = final_model)
# ============================================================
# Additional Model Diagnostics and Deeper Analysis
# ============================================================
library(ggplot2)
library(dplyr)
library(pdp) # For Partial Dependence Plots
library(DALEX) # For model explainability
library(ggthemes)
library(sf)
# ---------------------------
# 1. SHAP Dependence and Interaction Plots
# ---------------------------
message("\nGenerating SHAP dependence and interaction plots...")
# Assuming shap_values and shap_long are already computed
# (if not, recompute them using iml or SHAPforxgboost packages)
# Top feature by SHAP importance
# count(wt = abs(value)) sums |SHAP| per variable (proportional to the usual
# mean-|SHAP| importance); slice(1) keeps the top-ranked feature
top_feature <- shap_long %>%
  as_tibble() %>%
  count(variable, wt = abs(value), sort = TRUE) %>%
  dplyr::slice(1) %>%
  pull(variable)
# Dependence plot for top feature
# NOTE(review): color_feature equals x here, so the coloring adds no extra
# information — possibly a second feature was intended; confirm.
shap.plot.dependence(data_long = shap_long, x = top_feature, color_feature = top_feature)
# Interaction values
# predinteraction = TRUE asks xgboost for per-sample SHAP interaction terms
shap_interaction_values <- predict(
  final_model,
  as.matrix(X_train),
  predinteraction = TRUE
)
# shap_interaction_values will be a 3D array: [n_samples, n_features, n_features]
dim(shap_interaction_values)
[1] 10816 48 48
# ---------------------------
# 2. Residual Plots and Mapping
# ---------------------------
message("\nComputing residuals and creating residual plots...")

# Recompute test-set predictions and residuals for the diagnostic plots
preds <- predict(final_model, as.matrix(X_test))
residuals <- y_test - preds
residual_df <- data.frame(
  observed = y_test,
  predicted = preds,
  residual = residuals
)

# Predicted vs Observed (identity line in red marks perfect prediction)
residual_df %>%
  ggplot(aes(x = predicted, y = observed)) +
  geom_point(alpha = 0.6) +
  geom_abline(slope = 1, intercept = 0, color = "red") +
  theme_minimal() +
  labs(title = "Predicted vs. Observed Crash Rates",
       x = "Predicted Crash Rate per 1,000",
       y = "Observed Crash Rate per 1,000")

# Residual Histogram
residual_df %>%
  ggplot(aes(x = residual)) +
  geom_histogram(binwidth = 0.2, fill = "steelblue", color = "white") +
  theme_minimal() +
  labs(title = "Residual Distribution", x = "Residuals", y = "Count")
# ---------------------------
# 3. Partial Dependence Plots (PDP)
# ---------------------------
message("\nGenerating Partial Dependence Plots...")
# Rank features by summed |SHAP| (same ranking used for top_feature) and keep 10
top_features <- shap_long %>%
  count(variable, wt = abs(value), sort = TRUE) %>%
  dplyr::slice(1:10) %>%
  pull(variable)
# One PDP per top feature; pdp::partial needs the original training matrix
for (f in top_features) {
  pd <- partial(final_model, pred.var = f, train = as.matrix(X_train), grid.resolution = 30)
  # NOTE(review): plot() on a "partial" object dispatches to plot.data.frame;
  # pdp's native renderer is plotPartial(pd) — confirm intended output.
  plot(pd, main = paste("Partial Dependence of", f))
}